# -*- encoding: utf-8 -*-
'''
@File    :   metrics.py
@Time    :   2021/11/22 9:29
@Author  :   ZhangChaoYang
@Desc    :   Error metrics for comparing a reference x with an estimate x_hat (mean relative error).
'''

import tensorflow as tf
import numpy as np


def relative_error(x, x_hat, eps=1e-4):
    """Return the mean relative error of ``x_hat`` with respect to ``x``.

    The element-wise error ``|x - x_hat| / (|x| + eps)`` is averaged over the
    last axis, and the result is averaged again over its (new) last axis.
    Together, the last two axes of the input are reduced: a 2-D input gives a
    scalar, and a 3-D ``(a, b, c)`` input gives a tensor of shape ``(a,)``.

    Args:
        x: Reference values (array-like or tensor). Expected to have rank >= 2,
            because two trailing axes are reduced.
        x_hat: Estimated values, same shape as ``x`` (or broadcast-compatible).
            Cast to the dtype used for ``x``.
        eps: Small constant added to ``|x|`` in the denominator so that zero
            reference values do not cause a division by zero.

    Returns:
        A tensor holding the mean relative error. Its dtype is float32 for
        integer ``x``; otherwise it is the dtype of ``x``.
    """
    x = tf.convert_to_tensor(x)
    if x.dtype.is_integer:
        # TF does not implicitly promote integer tensors to float, so
        # ``tf.abs(x) + eps`` would raise for integer inputs. Compute in float32.
        x = tf.cast(x, tf.float32)
    # Align dtypes so the subtraction and division below are well-typed.
    x_hat = tf.cast(x_hat, x.dtype)
    elementwise = tf.abs(x - x_hat) / (tf.abs(x) + eps)
    # Two successive means over the trailing axis reduce the last two axes.
    return tf.reduce_mean(tf.reduce_mean(elementwise, axis=-1), axis=-1)


if __name__ == '__main__':
    # Small smoke test. Use float arrays: the metric divides by |x| + 1e-4,
    # which requires a floating-point dtype.
    x = np.asarray([[1, 2], [3, 4]], dtype=np.float32)
    x_hat = np.asarray([[5, 6], [7, 8]], dtype=np.float32)
    # Every element differs by 4, so the result is the mean of
    # 4 / (|x| + 1e-4) over all elements (about 2.0832).
    print(relative_error(x, x_hat))
